# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

"""
Image annotation/search (image-sentence ranking) for COCO with PyTorch
"""
from __future__ import absolute_import, division, unicode_literals

import logging
import copy
import numpy as np

import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim


class COCOProjNet(nn.Module):
    """
    Linear projections of image and sentence features into a shared space.

    Each modality goes through its own single linear layer and every
    projected row is rescaled to unit L2 norm, so the dot products returned
    by ``forward`` are cosine similarities.
    """

    def __init__(self, config):
        # config must provide the integer sizes "imgdim", "sentdim", "projdim"
        super(COCOProjNet, self).__init__()
        self.imgdim = config["imgdim"]
        self.sentdim = config["sentdim"]
        self.projdim = config["projdim"]
        # Image layer is built before the sentence layer (parameter order).
        self.imgproj = nn.Sequential(nn.Linear(self.imgdim, self.projdim))
        self.sentproj = nn.Sequential(nn.Linear(self.sentdim, self.projdim))

    @staticmethod
    def _unit_rows(mat):
        # Divide every row of a 2-D tensor by its own L2 norm.
        norms = torch.sqrt(torch.pow(mat, 2).sum(1, keepdim=True))
        return mat / norms.expand_as(mat)

    def forward(self, img, sent, imgc, sentc):
        """
        Score matching pairs and contrastive pairs.

        img   : (bsize, imgdim)              sent  : (bsize, sentdim)
        imgc  : (bsize, ncontrast, imgdim)   sentc : (bsize, ncontrast, sentdim)

        Returns four tensors of size (bsize*ncontrast): the image/sentence
        similarity (computed twice, as anchor1 and anchor2), the similarity of
        each image with its contrastive sentences, and the similarity of each
        sentence with its contrastive images.
        """
        # Repeat every positive example once per contrastive example, then
        # flatten all inputs to 2-D so the linear layers can be applied.
        flat_img = img.unsqueeze(1).expand_as(imgc).contiguous()
        flat_img = flat_img.view(-1, self.imgdim)
        flat_imgc = imgc.view(-1, self.imgdim)
        flat_sent = sent.unsqueeze(1).expand_as(sentc).contiguous()
        flat_sent = flat_sent.view(-1, self.sentdim)
        flat_sentc = sentc.view(-1, self.sentdim)

        # Unit-norm projections, each of size (bsize*ncontrast, projdim)
        img_unit = self.proj_image(flat_img)
        imgc_unit = self.proj_image(flat_imgc)
        sent_unit = self.proj_sentence(flat_sent)
        sentc_unit = self.proj_sentence(flat_sentc)

        anchor1 = (img_unit * sent_unit).sum(1)
        anchor2 = (sent_unit * img_unit).sum(1)
        img_sentc = (img_unit * sentc_unit).sum(1)
        sent_imgc = (sent_unit * imgc_unit).sum(1)

        # (bsize*ncontrast)
        return anchor1, anchor2, img_sentc, sent_imgc

    def proj_sentence(self, sent):
        # (bsize, sentdim) -> unit-norm rows of size (bsize, projdim)
        return self._unit_rows(self.sentproj(sent))

    def proj_image(self, img):
        # (bsize, imgdim) -> unit-norm rows of size (bsize, projdim)
        return self._unit_rows(self.imgproj(img))


class PairwiseRankingLoss(nn.Module):
    """
    Summed hinge ranking loss over both retrieval directions.

    A term is non-zero whenever a contrastive score is not at least
    ``margin`` below the score of the matching pair.
    """

    def __init__(self, margin):
        super(PairwiseRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor1, anchor2, img_sentc, sent_imgc):
        # Image compared with contrastive sentences.
        sent_violation = (self.margin - anchor1 + img_sentc).clamp(min=0.0)
        # Sentence compared with contrastive images.
        img_violation = (self.margin - anchor2 + sent_imgc).clamp(min=0.0)
        return sent_violation.sum() + img_violation.sum()


class ImageSentenceRankingPytorch(object):
    """
    Image/sentence ranking on COCO with PyTorch.

    Trains a ``COCOProjNet`` with ``PairwiseRankingLoss`` on pre-computed
    image and sentence features, selects the best model on the validation
    data and reports retrieval metrics (recall@1/5/10 and median rank) in
    both directions.

    ``train``, ``valid`` and ``test`` are mappings with the entries
    ``"imgfeat"`` and ``"sentfeat"``; ``config`` must provide ``"seed"``,
    ``"projdim"`` and ``"margin"``.

    The model runs on CUDA when it is available and on the CPU otherwise.
    """

    # Evaluation protocol used by run(): consecutive folds of FOLD_SIZE rows.
    NFOLDS = 5
    FOLD_SIZE = 5000

    def __init__(self, train, valid, test, config):
        # fix seed
        self.seed = config["seed"]
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        # Use the GPU when there is one instead of requiring it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.train = train
        self.valid = valid
        self.test = test

        self.imgdim = len(train["imgfeat"][0])
        self.sentdim = len(train["sentfeat"][0])
        self.projdim = config["projdim"]
        self.margin = config["margin"]

        self.batch_size = 128
        self.ncontrast = 30
        self.maxepoch = 20
        self.early_stop = True

        config_model = {
            "imgdim": self.imgdim,
            "sentdim": self.sentdim,
            "projdim": self.projdim,
        }
        self.model = COCOProjNet(config_model).to(self.device)

        self.loss_fn = PairwiseRankingLoss(margin=self.margin).to(self.device)

        self.optimizer = optim.Adam(self.model.parameters())

    def prepare_data(self, trainTxt, trainImg, devTxt, devImg, testTxt, testImg):
        """
        Convert all feature sets to float tensors.

        Training features stay on the CPU (batches are moved one at a time in
        ``trainepoch``); dev and test features are moved to ``self.device``.
        """
        trainTxt = torch.FloatTensor(trainTxt)
        trainImg = torch.FloatTensor(trainImg)
        devTxt = torch.FloatTensor(devTxt).to(self.device)
        devImg = torch.FloatTensor(devImg).to(self.device)
        testTxt = torch.FloatTensor(testTxt).to(self.device)
        testImg = torch.FloatTensor(testImg).to(self.device)

        return trainTxt, trainImg, devTxt, devImg, testTxt, testImg

    def run(self):
        """
        Train with model selection on the dev folds, then evaluate on test.

        Returns a 9-tuple: best dev score, then r1, r5, r10, medr for
        image-to-text and r1, r5, r10, medr for text-to-image, each averaged
        over the test folds.
        """
        self.nepoch = 0
        bestdevscore = -1
        bestmodel = None
        early_stop_count = 0
        stop_train = False

        # Preparing data
        logging.info("prepare data")
        trainTxt, trainImg, devTxt, devImg, testTxt, testImg = self.prepare_data(
            self.train["sentfeat"],
            self.train["imgfeat"],
            self.valid["sentfeat"],
            self.valid["imgfeat"],
            self.test["sentfeat"],
            self.test["imgfeat"],
        )

        # Training
        # NOTE: the bound is inclusive, so up to maxepoch + 1 epochs are run.
        while not stop_train and self.nepoch <= self.maxepoch:
            logging.info("start epoch")
            self.trainepoch(trainTxt, trainImg, devTxt, devImg, nepoches=1)
            logging.info("Epoch {0} finished".format(self.nepoch))

            results, score = self._evaluate_folds(devImg, devTxt, log_folds=True)

            logging.info(
                "Dev mean Text to Image: {0}, {1}, {2}, {3}".format(
                    results["t2i"]["r1"],
                    results["t2i"]["r5"],
                    results["t2i"]["r10"],
                    results["t2i"]["medr"],
                )
            )
            logging.info(
                "Dev mean Image to text: {0}, {1}, {2}, {3}".format(
                    results["i2t"]["r1"],
                    results["i2t"]["r5"],
                    results["i2t"]["r10"],
                    results["i2t"]["medr"],
                )
            )

            # Model selection / early stopping on the summed dev recalls.
            # The counter of non-improving epochs is never reset.
            if score > bestdevscore:
                bestdevscore = score
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        if bestmodel is not None:
            self.model = bestmodel

        # Compute test for the folds
        results, _ = self._evaluate_folds(testImg, testTxt)

        return (
            bestdevscore,
            results["i2t"]["r1"],
            results["i2t"]["r5"],
            results["i2t"]["r10"],
            results["i2t"]["medr"],
            results["t2i"]["r1"],
            results["t2i"]["r5"],
            results["t2i"]["r10"],
            results["t2i"]["medr"],
        )

    def _evaluate_folds(self, images, captions, log_folds=False):
        # Average i2t/t2i metrics over the folds; also return the fold-averaged
        # sum of the six recall values, which run() uses as the dev score.
        metric_names = ("r1", "r5", "r10", "medr")
        results = {
            "i2t": {"r1": 0, "r5": 0, "r10": 0, "medr": 0},
            "t2i": {"r1": 0, "r5": 0, "r10": 0, "medr": 0},
        }
        score = 0
        for fold in range(self.NFOLDS):
            start = fold * self.FOLD_SIZE
            stop = (fold + 1) * self.FOLD_SIZE
            img_fold = images[start:stop]
            txt_fold = captions[start:stop]

            # Ranks img2txt
            i2t = self.i2t(img_fold, txt_fold)
            for name, value in zip(metric_names, i2t):
                results["i2t"][name] += value / self.NFOLDS
            if log_folds:
                logging.info("Image to text: {0}, {1}, {2}, {3}".format(*i2t))

            # Ranks txt2img
            t2i = self.t2i(img_fold, txt_fold)
            for name, value in zip(metric_names, t2i):
                results["t2i"][name] += value / self.NFOLDS
            if log_folds:
                logging.info("Text to Image: {0}, {1}, {2}, {3}".format(*t2i))

            score += (
                i2t[0] + i2t[1] + i2t[2] + t2i[0] + t2i[1] + t2i[2]
            ) / self.NFOLDS
        return results, score

    def trainepoch(self, trainTxt, trainImg, devTxt, devImg, nepoches=1):
        """
        Run ``nepoches`` passes over the training data and update the model.

        Every 500 batches the current model is evaluated on the complete dev
        set and the metrics are logged. ``self.nepoch`` is advanced by
        ``nepoches`` at the end.
        """
        self.model.train()
        nsamples = len(trainTxt)
        for _ in range(nepoches):
            permutation = np.random.permutation(nsamples)
            for i in range(0, nsamples, self.batch_size):
                # forward
                if i % (self.batch_size * 500) == 0 and i > 0:
                    logging.info("samples : {0}".format(i))
                    r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg, devTxt)
                    logging.info(
                        "Image to text: {0}, {1}, {2}, {3}".format(
                            r1_i2t, r5_i2t, r10_i2t, medr_i2t
                        )
                    )
                    # Compute dev ranks txt2img
                    r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg, devTxt)
                    logging.info(
                        "Text to Image: {0}, {1}, {2}, {3}".format(
                            r1_t2i, r5_t2i, r10_t2i, medr_t2i
                        )
                    )
                batch = permutation[i : i + self.batch_size]
                idx = torch.from_numpy(batch).long()
                imgbatch = trainImg.index_select(0, idx).to(self.device)
                sentbatch = trainTxt.index_select(0, idx).to(self.device)

                # Contrastive examples are drawn (with replacement) from every
                # training index outside the current batch. The pool is built
                # once and shared by both draws; the image draw comes first.
                candidates = np.concatenate(
                    (permutation[:i], permutation[i + self.batch_size :])
                )
                ndraws = self.ncontrast * len(batch)
                idximgc = torch.from_numpy(
                    np.random.choice(candidates, ndraws)
                ).long()
                idxsentc = torch.from_numpy(
                    np.random.choice(candidates, ndraws)
                ).long()

                imgcbatch = (
                    trainImg.index_select(0, idximgc)
                    .view(-1, self.ncontrast, self.imgdim)
                    .to(self.device)
                )
                sentcbatch = (
                    trainTxt.index_select(0, idxsentc)
                    .view(-1, self.ncontrast, self.sentdim)
                    .to(self.device)
                )

                anchor1, anchor2, img_sentc, sent_imgc = self.model(
                    imgbatch, sentbatch, imgcbatch, sentcbatch
                )
                # loss
                loss = self.loss_fn(anchor1, anchor2, img_sentc, sent_imgc)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += nepoches

    def _project(self, images, captions):
        # Project both inputs with the model, batch_size rows at a time, and
        # return the concatenated (detached) embeddings.
        img_embed, sent_embed = [], []
        for start in range(0, len(images), self.batch_size):
            stop = start + self.batch_size
            img_embed.append(self.model.proj_image(images[start:stop]))
            sent_embed.append(self.model.proj_sentence(captions[start:stop]))
        return torch.cat(img_embed, 0).detach(), torch.cat(sent_embed, 0).detach()

    @staticmethod
    def _recall_metrics(ranks):
        # ranks are 0-based: recall@k is the percentage of ranks below k and
        # the median rank is reported 1-based.
        total = len(ranks)
        r1 = 100.0 * int(np.count_nonzero(ranks < 1)) / total
        r5 = 100.0 * int(np.count_nonzero(ranks < 5)) / total
        r10 = 100.0 * int(np.count_nonzero(ranks < 10)) / total
        medr = np.floor(np.median(ranks)) + 1
        return (r1, r5, r10, medr)

    def t2i(self, images, captions):
        """
        Text-to-image retrieval: rank the images for every caption.

        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Every fifth image row is used as the candidate for its group of five
        captions. Returns (r1, r5, r10, medr).
        """
        with torch.no_grad():
            # Project images and captions
            img_embed, sent_embed = self._project(images, captions)

            npts = img_embed.size(0) // 5
            # Built on the embeddings' own device, so this also runs on CPU.
            idxs = torch.arange(0, len(img_embed), 5, device=img_embed.device)
            ims = img_embed.index_select(0, idxs)

            ranks = np.zeros(5 * npts)
            for index in range(npts):

                # Get query captions
                queries = sent_embed[5 * index : 5 * index + 5]

                # Compute scores
                scores = torch.mm(queries, ims.transpose(0, 1)).cpu().numpy()
                for i, row in enumerate(scores):
                    order = np.argsort(row)[::-1]
                    ranks[5 * index + i] = np.where(order == index)[0][0]

            # Compute metrics
            return self._recall_metrics(ranks)

    def i2t(self, images, captions):
        """
        Image-to-text retrieval: rank all captions for every image.

        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Every fifth image row is used as a query; its rank is the best rank
        among its five captions. Returns (r1, r5, r10, medr).
        """
        with torch.no_grad():
            # Project images and captions
            img_embed, sent_embed = self._project(images, captions)

            npts = img_embed.size(0) // 5

            ranks = np.zeros(npts)
            for index in range(npts):

                # Get query image
                query_img = img_embed[5 * index]

                # Compute scores
                scores = torch.mm(
                    query_img.view(1, -1), sent_embed.transpose(0, 1)
                ).view(-1)
                inds = np.argsort(scores.cpu().numpy())[::-1]

                # Score: best position among the five matching captions
                ranks[index] = min(
                    np.where(inds == i)[0][0]
                    for i in range(5 * index, 5 * index + 5)
                )

            # Compute metrics
            return self._recall_metrics(ranks)
